from Network.network import Network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy, time
from Network.network_utils import reduce_function, get_acti, pytorch_model, cuda_string
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork
from Network.General.Factor.Attention.attn_utils import evaluate_key_query, mask_query, init_key_query_args, init_final_args
from Network.General.Factor.Attention.base_attention import BaseMaskedAttentionNetwork

class MultiHeadAttentionParallelLayer(Network):
    """Multi-head attention layer whose mask is applied to the attention weights.

    Every (key, query) pair gets its own value vector, computed by
    ``value_network`` from the inputs assembled in ``append_values``, so all keys
    are handled in one batched pass instead of one pass per key.
    """

    def __init__(self, args):
        super().__init__(args)
        self.softmax = nn.Softmax(-1)
        self.fp = args.factor
        self.ap = args.mask_attn
        self.append_keys = args.factor_net.append_keys
        # dimension of the key inputs (embed_dim)
        # NOTE(review): an older note said this must equal model_dim * num_heads; in this
        # class only the key/query network OUTPUTS are reshaped into (num_heads, model_dim)
        self.key_dim = args.embed_dim
        self.query_dim = args.embed_dim
        self.num_heads = args.mask_attn.num_heads
        self.merge_function = args.mask_attn.merge_function
        self.model_dim = args.mask_attn.model_dim # the dimension of the keys and queries after network
        self.append_mask = args.factor_net.append_mask
        # number of copies of (1 - mask) appended to every value input; 0/False disables it
        self.append_broadcast_mask = args.factor_net.append_broadcast_mask

        # process all keys at once
        key_query_args = init_key_query_args(args)
        self.key_network = ConvNetwork(key_query_args)
        self.query_network = ConvNetwork(key_query_args)

        # value inputs are [key (optional), key mask row (optional), query, broadcast mask (optional)],
        # matching what append_values builds
        value_args = init_key_query_args(args, use_broadcast_mask=True) # only differs if append_keys
        value_args.object_dim = (args.embed_dim + int(self.append_keys) * args.embed_dim
                                  + int(self.append_mask) * self.fp.num_objects
                                  + self.append_broadcast_mask)
        self.value_network = ConvNetwork(value_args) # values append keys internally

        final_args = init_final_args(args)
        self.final_network = ConvNetwork(final_args)

        self.model = [self.key_network, self.query_network, self.value_network, self.final_network]

    def append_values(self, keys, queries, mask): # TODO: modify so that we can apply feature masks here
        """Build the per-(key, query) inputs for the value network.

        Args:
            keys: batch x num_keys x key_dim.
            queries: batch x num_queries x query_dim.
            mask: indexed here as a 3-D tensor, batch x num_keys x num_queries.

        Returns:
            A new tensor of shape batch x num_keys x num_queries x value_input_dim.
            Entry [b, k, q] concatenates, in order (each only when enabled):
            keys[b, k] (append_keys), mask[b, k] (append_mask), queries[b, q],
            and ``append_broadcast_mask`` copies of 1 - mask[b, k, q].
        """
        num_keys, num_queries = keys.shape[1], queries.shape[1]
        parts = list()

        # key-side features are identical for every query, so broadcast them along a query axis.
        # The mask row is appended whenever append_mask is set (not only with append_keys), so the
        # value inputs always match the object_dim declared for value_network in __init__.
        key_side = list()
        if self.append_keys:
            key_side.append(keys)
        if self.append_mask:
            key_side.append(mask)
        if key_side:
            key_feats = torch.cat(key_side, dim=-1) # batch x num_keys x key_side_dim
            parts.append(key_feats.unsqueeze(2).expand(-1, -1, num_queries, -1))

        # every key sees the full set of queries
        parts.append(queries.unsqueeze(1).expand(-1, num_keys, -1, -1))

        if self.append_broadcast_mask:
            broadcast_shape = (mask.shape[0], mask.shape[1], mask.shape[2], int(self.append_broadcast_mask))
            parts.append(1 - mask.unsqueeze(-1).broadcast_to(broadcast_shape))

        # torch.cat always allocates its output, so the result never aliases queries
        return torch.cat(parts, dim=-1)

    def forward(self, keys, queries, mask, query_final=False):
        """Attend from every key over the queries, with masking applied to the weights.

        Args:
            keys: batch x num_keys x key_dim.
            queries: batch x num_queries x query_dim.
            mask: passed to append_values (which indexes it as batch x num_keys x num_queries)
                and to evaluate_key_query.
            query_final: if True, keep the query axis and skip the final network.

        Returns:
            (values, weights, embed). weights are batch x heads x keys x queries, as produced
            by evaluate_key_query. With query_final, values keep the query axis, presumably
            batch x keys x queries x merged_dim depending on merge_function, and embed is the
            input keys. Otherwise values are the final network output over keys, and embed is
            the same as values.
        """
        embed = keys # if this is a query_final, the last embed is the keys
        batch_size, num_keys, num_queries = keys.shape[0], keys.shape[1], queries.shape[1]
        value_inputs = self.append_values(keys, queries, mask)

        # batch, num_keys, embed_dim -> batch, num_keys, model_dim * num_heads
        keys = self.key_network(keys.transpose(-2, -1)).transpose(-2, -1)
        # batch, num_queries, embed_dim -> batch, num_queries, model_dim * num_heads
        queries = self.query_network(queries.transpose(-2, -1)).transpose(-2, -1)
        # -> batch, num_heads, model_dim, num_keys
        keys = keys.reshape(batch_size, -1, self.num_heads, self.model_dim).transpose(1, 2).transpose(2, 3)
        # -> batch, num_heads, num_queries, model_dim
        queries = queries.reshape(batch_size, -1, self.num_heads, self.model_dim).transpose(1, 2)

        # batch, num_keys * num_queries, value_input_dim -> batch, num_keys * num_queries, num_heads * model_dim
        flat_inputs = value_inputs.reshape(batch_size, num_keys * num_queries, -1)
        values = self.value_network(flat_inputs.transpose(-1, -2)).transpose(-1, -2)
        # -> batch x num_heads x num_keys x num_queries x model_dim
        values = values.reshape(batch_size, num_keys, num_queries, self.num_heads, self.model_dim).permute(0, 3, 1, 2, 4)

        # get the weights using the keys and queries: batch x heads x keys x queries
        weights = evaluate_key_query(self.softmax, keys, queries, mask, single_key=False,
                                     gumbel=self.ap.gumbel_attention, renormalize=self.ap.renormalize,
                                     use_sigmoid=query_final and self.ap.bernoulli_weights)
        # batch x heads x keys x queries x model_dim; values are reweighted but NOT softmaxed over
        values = weights.unsqueeze(-1) * values

        if query_final: # preserve the queries by not performing the final network
            if self.ap.merge_function == 'cat':
                # move model_dim next to heads so the head merge can concatenate them
                values = values.transpose(4, 3).transpose(3, 2)
            values = reduce_function(self.ap.merge_function, values, dim=1) # merges the heads
            if self.ap.merge_function == 'cat':
                values = values.transpose(1, 2).transpose(2, 3)
        else: # sum along the query dimension (already normalized by weights) and apply the final network
            values = values.sum(dim=-2) # batch x heads x keys x model_dim
            if self.ap.merge_function == "cat":
                # cat requires the heads and value dimension to get concatenated
                values = reduce_function(self.ap.merge_function, values.transpose(1, 2), dim=2)
            else:
                values = reduce_function(self.ap.merge_function, values, dim=1)
            values = self.final_network(values.transpose(-2, -1)).transpose(-2, -1) # batch x keys x out_dim
            embed = values
        return values, weights, embed

class ParallelMaskedAttentionNetwork(BaseMaskedAttentionNetwork):
    def __init__(self, args):
        '''
        Masked attention network in which masking is applied to the attention
        weights rather than to the embeddings. This lets all keys be processed
        in parallel, possibly at some cost in accuracy compared with masking
        the embeddings.
        TODO: logic here is shared with mask_attention (merge them); it is also
        shared with key_pair, though merging that one is harder.
        '''
        super().__init__(args, MultiHeadAttentionParallelLayer)
        self.fp = args.factor
        # switch to training mode, then (re)initialise the parameters
        self.train()
        self.reset_network_parameters()

    def compute_attention(self, key, query, mask):
        """Run the multi-head attention layer on (key, query, mask).

        Returns the layer output unchanged. The layer is presumably the
        MultiHeadAttentionParallelLayer instance built by the base class; if so,
        its forward returns the tuple (values, weights, embed).
        """
        attention_layer = self.multi_head_attention
        return attention_layer(key, query, mask)